[Feature][Blackwell] Add SM120 T.float4_e2m1fn FP4 GEMM support #2171
TerminusAkivili wants to merge 4 commits
Conversation
📝 Walkthrough

This PR implements SM120 (compute capability 12.0) FP4 (float4_e2m1fn) GEMM support across TileLang: examples and host unpacking, CUDA codegen FP4 storage/indexing/vector/scalar handling, FP4-aware cp.async injection, b4x16 ldmatrix helpers, CuTe SM120 MMA dispatch for FP4/mixed operands, layout/macro generation changes, and GemmMMA integration.
Actionable comments posted: 3
🧹 Nitpick comments (1)
src/tl_templates/cuda/cuda_fp4.h (1)
166-187: ⚡ Quick win: Verify register allocation for `fp4_e2_t values[64]` in device code.

The 64-element local array is constant-indexed throughout (`values[0]`–`values[63]`), so nvcc at `-O2`+ should scalar-replace it into registers. However, unlike the explicitly-parameterized `make_fp4_e2_32_t`, which guarantees register-only arguments, register spilling to local memory is possible at lower optimisation levels or with larger surrounding register pressure. Consider adding a `__forceinline__` annotation to maximise inlining and scalar replacement at call sites.

Proposed annotation:

```diff
-template <typename... Args>
-TL_DEVICE fp4_e2_64_t make_fp4_e2_64_t(Args... args) {
+template <typename... Args>
+TL_DEVICE __forceinline__ fp4_e2_64_t make_fp4_e2_64_t(Args... args) {
```
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/backend/cuda/codegen/codegen_cuda.cc`:
- Around line 1973-2003: The FP4 padded shared-memory vector path
(IsFp4PaddedSharedStorage + code using GetFp4PaddedSharedIndex and the
byte_offset lambda when constructing the reinterpret cast for t.lanes()) can
incorrectly span the padded 16-element row boundary; add a guard or split logic:
either assert the logical base alignment (e.g., Ensure base % 16 == 0 for the
requested load/store) or detect when the access crosses a 16-element row by
computing the start and end logical indices (base + offset and base + offset +
t.lanes()-1) and comparing their 16-element row indices (truncdiv(..., 16)); if
it crosses, split the operation into two row-aligned fragments (like the
existing t.lanes()==32 two-fragment approach) and merge them, otherwise keep the
current single contiguous byte reinterpretation; apply the same fix to the other
similar blocks identified (around the other ranges mentioned).
- Around line 4428-4444: The allocator treats only scope == "local" as the path
that emits local backing arrays but FP4 fragments use the semantic storage name
"local.fragment", so allocations for these still hit the unsupported-scope
branch; update the scope checks used around is_int4_scalar_local, the FP4
alignas(16) branch, and the place that prints/omits the storage scope to treat
"local.fragment" as equivalent to "local" (either normalize scope to "local"
earlier or change conditions from scope == "local" to (scope == "local" || scope
== "local.fragment")), ensuring PrintStorageScope/PrintType and the
backing-array emission path handle FP4 fragments the same as regular local
allocations (references: is_int4_scalar_local, op->dtype.is_float4_e2m1fn(),
PrintStorageScope, PrintType, and the "local.fragment" semantic storage).
In `@tilelang/cuda/intrinsics/macro/mma_macro_generator.py`:
- Around line 121-124: The FP4 fast-path in mma_macro_generator.py sets
self.k_dim = 32 without respecting self.chunk, causing micro_size_k to exceed
chunk when chunk < 32; update the FP4 branch in the initializer (the block
setting self.k_dim) to clamp k_dim by self.chunk (e.g., self.k_dim = min(32,
self.chunk)) and add the same clamp/guard in the subclass override (the code
around lines 873–877) so both places respect chunk; optionally emit a clear
ValueError or assertion if chunk < required minimum to fail early with a helpful
message referencing the dtype and chunk size.
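A minimal Python sketch of the suggested clamp, using the attribute names from the review comment (`self.k_dim`, `self.chunk`); the real initializer may look different:

```python
class MmaMacroGeneratorSketch:
    """Hypothetical stand-in for the FP4 fast-path in the initializer."""

    def __init__(self, a_dtype: str, chunk: int):
        self.chunk = chunk
        if a_dtype == "float4_e2m1fn":
            # Clamp by chunk instead of hard-coding 32 (review suggestion),
            # and fail early when the two cannot line up.
            self.k_dim = min(32, chunk)
            if chunk % self.k_dim != 0:
                raise ValueError(
                    f"chunk={chunk} is incompatible with float4_e2m1fn "
                    f"k_dim={self.k_dim}"
                )
        else:
            self.k_dim = chunk

assert MmaMacroGeneratorSketch("float4_e2m1fn", 16).k_dim == 16  # clamped
assert MmaMacroGeneratorSketch("float4_e2m1fn", 64).k_dim == 32
```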
---
Nitpick comments:
In `@src/tl_templates/cuda/cuda_fp4.h`:
- Around line 166-187: The local array fp4_e2_t values[64] in make_fp4_e2_64_t
may be spilled under some compile conditions; annotate the function to force
inlining (e.g., add a __forceinline__/always-inline device inline attribute to
make_fp4_e2_64_t) so nvcc can scalar-replace values[0]..values[63] into
registers and inline the make_fp4_e2_32_t calls; update the function declaration
for make_fp4_e2_64_t accordingly (keeping fp4_e2_t values[64] and the existing
make_fp4_e2_32_t usages unchanged).
📒 Files selected for processing (16)

- examples/gemm_fp4/example_gemm_a8w4_sm120.py
- examples/gemm_fp4/example_gemm_fp4_sm120.py
- src/backend/cuda/codegen/codegen_cuda.cc
- src/backend/cuda/codegen/codegen_cuda.h
- src/backend/cuda/op/copy.cc
- src/backend/cuda/op/copy_analysis.cc
- src/tl_templates/cuda/cuda_fp4.h
- src/tl_templates/cuda/gemm_mma.h
- src/tl_templates/cuda/instruction/mma.h
- src/tl_templates/cuda/ldsm.h
- src/transform/lower_ptx_async_copy.cc
- src/transform/ptx_async_copy_injector.h
- tilelang/cuda/intrinsics/layout/mma_layout.py
- tilelang/cuda/intrinsics/layout/utils.py
- tilelang/cuda/intrinsics/macro/mma_macro_generator.py
- tilelang/cuda/op/gemm/gemm_mma.py
Force-pushed from 3e5823d to 7f254a9.
Hi @LeiWang1999, no rush at all. Feel free to check it whenever it's convenient for you. I'd love your feedback. Thank you!
- Lower FP4 packed vector load/store with odd or symbolic bases to per-lane nibble operations to avoid silent miscompiles.
- Reject `T.gemm` K tiles that are not divisible by the MMA instruction K tile so FP4/A8W4 `block_K` tails cannot be silently skipped.
Force-pushed from cb5bf3d to 795cb39.
Summary
This PR adds SM120 fragment-MMA GEMM support for `T.float4_e2m1fn`. It covers plain FP4 GEMM and explicit mixed FP8/FP4 GEMM while keeping the TileLang-facing API dtype-semantic.

Kernels continue to declare FP4 operands as `T.float4_e2m1fn`. Packed byte storage is handled by lowering/codegen and by host-side example setup; users do not have to model FP4 GEMM operands as `uint8` tensors in TileLang programs.
Supported GEMM combinations:

| A dtype | B dtype | Accumulator |
| --- | --- | --- |
| `T.float4_e2m1fn` | `T.float4_e2m1fn` | `T.float32` |
| `T.float8_e4m3fn` | `T.float4_e2m1fn` | `T.float32` |
| `T.float4_e2m1fn` | `T.float8_e4m3fn` | `T.float32` |

Design Goals

- Keep the public API dtype-semantic: FP4 GEMM operands stay `T.float4_e2m1fn` at the language level (see the sketch after this list).
- Do not expose `uint8` packing as a user-facing workaround.
- Preserve the existing int4/uint4 ldmatrix behavior.
- Support explicit mixed FP8/FP4 (A8W4/W4A8) GEMM.
- Fail fast on unsupported shapes rather than generate kernels that silently skip data.
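To make the dtype-semantic goal concrete, here is a minimal TileLang kernel sketch modeled on the repository's standard GEMM examples; tile sizes and the exact shape of the PR's example files are assumptions, but note that `block_K=64` respects the `m16n8k32` K-tile rule described below:

```python
import tilelang.language as T

def fp4_gemm(M, N, K, block_M=128, block_N=128, block_K=64):
    # Operands are declared as float4_e2m1fn; packed byte storage is
    # handled by lowering/codegen, not by the TileLang program.
    @T.prim_func
    def main(
        A: T.Tensor((M, K), "float4_e2m1fn"),
        B: T.Tensor((K, N), "float4_e2m1fn"),
        C: T.Tensor((M, N), "float32"),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M),
                      threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), "float4_e2m1fn")
            B_shared = T.alloc_shared((block_K, block_N), "float4_e2m1fn")
            C_local = T.alloc_fragment((block_M, block_N), "float32")
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=2):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main
```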
Hardware Contracts
The SM120 FP4 path has three contracts that must line up across lowering,
layout, and template dispatch:
| Contract | Key files |
| --- | --- |
| `b4x16_p64` ldmatrix | `src/tl_templates/cuda/ldsm.h`, `src/backend/cuda/codegen/codegen_cuda.cc`, `src/backend/cuda/op/copy.cc`, `tilelang/cuda/intrinsics/layout/utils.py`, `tilelang/cuda/intrinsics/layout/mma_layout.py` |
| FP4 padded shared storage and indexing | `src/backend/cuda/codegen/codegen_cuda.cc`, `src/backend/cuda/codegen/codegen_cuda.h` |
| `m16n8k32` MMA for explicit FP4/FP8 dtype pairs | `src/tl_templates/cuda/instruction/mma.h`, `src/tl_templates/cuda/gemm_mma.h`, `tilelang/cuda/intrinsics/macro/mma_macro_generator.py`, `tilelang/cuda/op/gemm/gemm_mma.py` |

The key implementation detail is that SM120 `b4x16_p64` consumes packed FP4 bytes from shared memory with a padded shared row layout. Global memory, shared memory, and local fragments therefore cannot all use the same offset model:

- Global memory uses packed byte offsets (two FP4 nibbles per byte).
- Shared memory uses the padded row layout that `b4x16_p64` expects.
- Local fragments use the register layout expected by the MMA lowering path.
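For intuition, a host-side Python sketch of the first two offset models. The packed-global mapping follows directly from two nibbles per byte; the padded-shared mapping only caricatures the internal `GetFp4PaddedSharedIndex` helper, and the pad width below is a made-up placeholder:

```python
def packed_global_offset(logical_idx: int) -> tuple[int, int]:
    """Packed global FP4: byte-addressed storage, two nibbles per byte."""
    return logical_idx // 2, logical_idx % 2  # (byte index, nibble slot)

def padded_shared_byte(logical_idx: int, row_pad_bytes: int = 8) -> int:
    """Caricature of the padded shared layout: each 16-element logical row
    occupies 8 payload bytes plus padding (pad size is illustrative only)."""
    row, col = divmod(logical_idx, 16)
    return row * (8 + row_pad_bytes) + col // 2

# Element 17 lives in byte 8 of the packed global stream, high nibble...
assert packed_global_offset(17) == (8, 1)
# ...but starts the second padded row in shared memory.
assert padded_shared_byte(17) == 16
```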
Implementation Details
Semantic FP4 Storage Model
This PR separates "what the TileLang program declares" from "how the FP4 payload is physically moved":

- Programs declare FP4 operands as `T.float4_e2m1fn`; the packed byte representation is an implementation detail.
- Codegen keeps the byte-level view behind internal `_packed` aliases.

This avoids leaking `uint8` into the public GEMM dtype model while still allowing the generated CUDA path to use the byte/nibble representation that the hardware requires.
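As a sketch of the host-side half of this split, packing 4-bit e2m1 codes two per byte; the low-nibble-first order and the helper name are assumptions, not the PR's actual example code:

```python
import numpy as np

def pack_fp4_codes(codes: np.ndarray) -> np.ndarray:
    """Pack an even-length array of 4-bit e2m1 codes (0..15) into bytes,
    low nibble first; mirrors the kind of setup the examples perform."""
    assert codes.ndim == 1 and codes.size % 2 == 0
    lo = codes[0::2].astype(np.uint8) & 0xF
    hi = codes[1::2].astype(np.uint8) & 0xF
    return lo | (hi << 4)

packed = pack_fp4_codes(np.array([0x1, 0x2, 0x3, 0x4]))
assert packed.tolist() == [0x21, 0x43]
```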
SM120 FP4 LDSM
`src/tl_templates/cuda/ldsm.h` adds SM120 `ptx_ldmatrix_b4x16_x{1,2,4}` helpers guarded for the target architecture. CUDA lowering selects these helpers when the source fragment is explicitly `float4_e2m1fn`.

The Python layout side adds FP4-specific ldmatrix logical layouts and offset handling. This is gated on `float4_e2m1fn`, so the existing int4/uint4 ldmatrix offset behavior stays on the previous path.
SM120 MMA Dispatch
The template dispatch adds SM120 `cute::SM120_16x8x32_TN` support for FP4xFP4, FP8xFP4, and FP4xFP8 to FP32.

The FP4 operand register payload adjustment is applied only to FP4 operands before calling the CuTe atom. Mixed A8W4/W4A8 dispatch is selected from the explicit A/B dtype pair, rather than inferred from a packed integer carrier.
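Conceptually the selection is a lookup on the explicit dtype pair. A Python sketch of that mapping; the table and function are illustrative, only the atom name and the three supported pairs come from this PR:

```python
# Supported SM120 FP4/FP8 operand pairs, all accumulating in float32.
SM120_FP4_PAIRS = {
    ("float4_e2m1fn", "float4_e2m1fn"),
    ("float8_e4m3fn", "float4_e2m1fn"),
    ("float4_e2m1fn", "float8_e4m3fn"),
}

def select_sm120_atom(a_dtype: str, b_dtype: str) -> str:
    if (a_dtype, b_dtype) not in SM120_FP4_PAIRS:
        raise ValueError(f"unsupported SM120 dtype pair: {a_dtype} x {b_dtype}")
    # All three pairs dispatch to the same atom shape; the element types
    # parameterize the atom on the C++ side.
    return "cute::SM120_16x8x32_TN"

assert select_sm120_atom("float8_e4m3fn", "float4_e2m1fn") == "cute::SM120_16x8x32_TN"
```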
Copy And Async Copy
The copy path distinguishes packed global FP4 offsets from padded shared FP4 offsets. Shared-to-fragment copy lowering routes SM120 FP4 through the new `b4x16_p64` ldmatrix path, while global-to-shared async copy keeps using the existing cp.async lowering with FP4 padded-shared metadata enabled.
For FP4 global-to-shared async copy, the lowering emits 8-byte segments for
packed FP4 storage and carries the extra metadata needed to place those bytes
into the padded shared-memory layout.
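A worked example of that byte arithmetic, assuming each shared row is filled one segment at a time (a sketch; the real lowering also carries the padded-shared placement metadata, which is not modeled here):

```python
def fp4_cp_async_segments(row_elems: int) -> int:
    """Number of 8-byte cp.async segments for one row of packed FP4."""
    row_bytes = row_elems // 2  # two FP4 nibbles per byte
    assert row_bytes % 8 == 0, "rows must split into whole 8-byte segments"
    return row_bytes // 8

# block_K = 64 logical FP4 elements -> 32 packed bytes -> 4 segments per row.
assert fp4_cp_async_segments(64) == 4
```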
K Tile Validation
SM120 FP4/A8W4 MMA consumes K in instruction-sized `m16n8k32` tiles. A `T.gemm` block K that is not divisible by the selected instruction K tile cannot be represented by the current lowering without dropping the leftover K range.

This PR therefore rejects unsupported K tile choices up front. For example, `block_K=48` is invalid for this path because it would execute one K=32 MMA tile and silently omit the remaining K=16 tail.
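A minimal sketch of that guard, assuming the instruction K tile is exposed as `micro_size_k` (the actual check lives in `tilelang/cuda/op/gemm/gemm_mma.py` and operates on the GEMM operator's attributes):

```python
def validate_block_k(block_k: int, micro_size_k: int = 32) -> None:
    """Reject T.gemm K tiles the SM120 FP4 lowering cannot represent."""
    if block_k % micro_size_k != 0:
        raise ValueError(
            f"block_K={block_k} is not divisible by the MMA instruction "
            f"K tile ({micro_size_k}); the K tail would be silently dropped"
        )

validate_block_k(64)  # ok: two m16n8k32 K tiles
try:
    validate_block_k(48)  # one K=32 tile plus an unrepresentable K=16 tail
except ValueError as err:
    print(err)
```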
Main Changes
CUDA Templates
- `src/tl_templates/cuda/ldsm.h`: `ptx_ldmatrix_b4x16_x{1,2,4}` helpers with architecture guard.
- `src/tl_templates/cuda/instruction/mma.h`: `cute::SM120_16x8x32_TN` dispatch for FP4xFP4, FP8xFP4, and FP4xFP8 to FP32.
- `src/tl_templates/cuda/instruction/mma.h`: FP4 operand register payload handling around the CuTe atoms.
- `src/tl_templates/cuda/gemm_mma.h`: GemmMMA integration for the SM120 FP4 path.
- `src/tl_templates/cuda/cuda_fp4.h`: FP4 storage types and constructor helpers (e.g. `make_fp4_e2_64_t`).

CUDA Lowering
- `src/backend/cuda/codegen/codegen_cuda.cc`: selects `ptx_ldmatrix_b4x16_x{1,2,4}` for explicit `float4_e2m1fn` ldmatrix loads.
- `src/backend/cuda/codegen/codegen_cuda.cc`: FP4 storage, indexing, and vector/scalar load-store handling, including internal `_packed` aliases.
- `src/backend/cuda/codegen/codegen_cuda.h`: declarations for the FP4 codegen paths.
- `src/backend/cuda/op/copy.cc`: packed-global vs padded-shared FP4 copy offsets and the `b4x16_p64` shared-to-fragment route.
- `src/backend/cuda/op/copy_analysis.cc`: FP4-aware copy analysis.
- `src/transform/lower_ptx_async_copy.cc`: FP4-aware cp.async lowering.
- `src/transform/ptx_async_copy_injector.h`: FP4-aware cp.async injection.

Python Lowering
- `tilelang/cuda/intrinsics/layout/utils.py`: ldmatrix offset handling gated on `float4_e2m1fn`.
- `tilelang/cuda/intrinsics/layout/mma_layout.py`: FP4-specific ldmatrix logical layouts.
- `tilelang/cuda/intrinsics/macro/mma_macro_generator.py`: `m16n8k32` MMA granularity.
- `tilelang/cuda/op/gemm/gemm_mma.py`: SM120 FP4/mixed-dtype GEMM dispatch and K tile validation.

Examples
- `examples/gemm_fp4/example_gemm_fp4_sm120.py`
- `examples/gemm_fp4/example_gemm_a8w4_sm120.py`

Changes Added After Initial Review
- `src/backend/cuda/codegen/codegen_cuda.cc`: lowers FP4 packed vector load/store with odd or symbolic bases to per-lane nibble operations.
- `tilelang/cuda/op/gemm/gemm_mma.py`: rejects `T.gemm` K tiles that are not divisible by the selected MMA instruction K tile, preventing FP4/A8W4 cases such as `block_K=48` from silently skipping the K tail.

Why The Review Fixes Matter
FP4 global packed storage is byte-addressed, but logical FP4 elements are nibble-addressed. A vector reinterpret load/store is only safe when the logical base offset is known to be even. If the offset is odd, or if codegen cannot prove that it is even, vectorized byte reinterpretation can read or write the wrong nibble without producing a compilation error.
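In pseudologic, the fixed codegen chooses the lowering like this (a sketch: the real check works on TIR expressions, where "symbolic" means the parity cannot be proven at compile time):

```python
from typing import Optional

def choose_fp4_lowering(base: Optional[int], lanes: int) -> str:
    """Pick a lowering for a packed FP4 vector access at logical `base`.

    `base is None` models a symbolic offset whose parity is unknown.
    """
    if base is not None and base % 2 == 0 and lanes % 2 == 0:
        return "vector_byte_reinterpret"  # provably byte-aligned span
    return "per_lane_nibble_ops"          # odd or symbolic base: stay correct

assert choose_fp4_lowering(16, 8) == "vector_byte_reinterpret"
assert choose_fp4_lowering(17, 8) == "per_lane_nibble_ops"
assert choose_fp4_lowering(None, 8) == "per_lane_nibble_ops"
```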
Likewise, SM120 FP4/A8W4 MMA consumes K in fixed instruction-sized chunks. Allowing a `block_K` such as 48 would make the generated kernel execute only the representable K=32 portion and miss the K tail. The new validation turns that silent numerical error into an explicit unsupported-shape error.
Validation
Local SM120 validation used an RTX 5090 / compute capability 12.0 environment.
Build and examples:
Observed numerical results:
Generated CUDA was inspected for the expected SM120 FP4 markers:
Focused tests:
Observed focused-test results:
Notes And Non-Goals
- `float4_e2m1fn`.

Status
All set now!